#!/usr/bin/env python3
"""
Create a results table from bootstrap results files.
Simple version that reads the new results structure.
"""

import argparse
import glob
import os
import sys

import numpy as np
import pandas as pd

# =============================================================================
# HYPERPARAMETERS - CONFIGURE THESE
# =============================================================================

# Internal model identifier -> display name used in the first table column.
MODEL_NAMES = {
    'bayesian_elo': 'BBT',
    'bayesian_elo_noise': 'BBTQ (ours)',
    'google_elo': 'BTQ'
}

# Column header for each dataset; keys have the form "{project}_{dataset}".
DATASET_NAMES = {
    'clic2024_2AFC_google_elo': '2AFC',
    'clic2024_2AFC_filtered_google_elo': '2AFC Filtered',
    'clic2024_2AFC_unfiltered_google_elo': '2AFC Unfiltered',
    'hific_userstudy_google_elo': 'HiFiC',
    'WD_answers': 'Wasserstein Distance',
    'conha_answers': 'Conha',
    'prolific_feedback_comparisons': 'Prolific',
    'mtbench_human_judgments': 'MTBench',
}

# Datasets to include in table (in column order) - tuples of (project, dataset)
DATASETS_TO_INCLUDE = [
    ('prolific', 'feedback_comparisons'),
    ('mtbench', 'human_judgments'),
    ('WD', 'answers'),
    ('hific', 'userstudy_google_elo'),
    ('conha', 'answers'),
    ('clic2024', '2AFC_google_elo'),
    ('clic2024', '2AFC_filtered_google_elo'),
    ('clic2024', '2AFC_unfiltered_google_elo'),
]

# Models to include (in row order)
MODELS_TO_INCLUDE = ['google_elo', 'bayesian_elo', 'bayesian_elo_noise']

# Metric columns that can be reported.
AVAILABLE_METRICS = ['Top1 agreement', 'Spearman correlation', "Kendall's tau", 'Time']
# Metric reported in the table. Set this to exactly one entry of
# AVAILABLE_METRICS; a second assignment would silently override the first.
SELECTED_METRIC = "Top1 agreement"

# When True, each cell also carries the standard deviation of the metric
# column and the LaTeX output shows it after a "±" sign.
INCLUDE_STD_ERROR = False

# =============================================================================
# MAIN FUNCTIONS
# =============================================================================

def load_results(results_dir, metric):
    """Collect the mean of ``metric`` for every configured dataset/model pair.

    For each ``(project, dataset)`` in ``DATASETS_TO_INCLUDE`` and each model
    in ``MODELS_TO_INCLUDE`` this reads
    ``{results_dir}/{project}/{dataset}/{model}/all/all/bootstrap_results.csv``
    and averages the ``metric`` column.

    Args:
        results_dir: Root directory of the results tree.
        metric: Name of the CSV column to aggregate.

    Returns:
        A dict mapping ``"{project}_{dataset}"`` to ``{model: value}``.
        ``value`` is the column mean, or a ``(mean, std)`` tuple when
        ``INCLUDE_STD_ERROR`` is true. Combinations whose file is missing,
        cannot be read, or lacks the metric column are left out; the latter
        two cases are reported on stderr instead of being dropped silently.
    """
    results = {}

    for project, dataset in DATASETS_TO_INCLUDE:
        dataset_key = f"{project}_{dataset}"

        for model in MODELS_TO_INCLUDE:
            file_path = os.path.join(
                results_dir, project, dataset, model,
                'all', 'all', 'bootstrap_results.csv',
            )

            # A missing file is treated as "this combination was not run"
            # and skipped quietly.
            if not os.path.exists(file_path):
                continue

            try:
                df = pd.read_csv(file_path)
            except (OSError, ValueError) as exc:
                # pandas' ParserError / EmptyDataError derive from ValueError.
                print(f"Warning: could not read {file_path}: {exc}", file=sys.stderr)
                continue

            if metric not in df.columns:
                print(
                    f"Warning: column {metric!r} not found in {file_path}; skipping",
                    file=sys.stderr,
                )
                continue

            try:
                mean_value = df[metric].mean()
                # Spread of the column's values; presumably one row per
                # bootstrap replicate, which makes this the bootstrap
                # standard error - confirm against the file producer.
                std_error = df[metric].std() if INCLUDE_STD_ERROR else None
            except (TypeError, ValueError) as exc:
                print(
                    f"Warning: could not aggregate {metric!r} in {file_path}: {exc}",
                    file=sys.stderr,
                )
                continue

            value = (mean_value, std_error) if INCLUDE_STD_ERROR else mean_value
            results.setdefault(dataset_key, {})[model] = value

    return results

def create_table(results, metric):
    """Arrange loaded results into a models-by-datasets DataFrame.

    Args:
        results: Mapping ``{dataset_key: {model: value}}`` where ``value`` is
            a number or a ``(mean, std)`` tuple.
        metric: Unused here; kept so the signature matches the other steps.

    Returns:
        A DataFrame indexed by display model name ('Model') with one column
        per key of ``results`` (in insertion order). Missing combinations are
        NaN; tuples are reduced to their mean when ``INCLUDE_STD_ERROR`` is
        set.
    """
    dataset_keys = list(results.keys())

    def cell(dataset_key, model):
        # Only the mean goes into the numeric table; the spread is looked up
        # again from `results` when the LaTeX is rendered.
        entry = results.get(dataset_key, {}).get(model, np.nan)
        if INCLUDE_STD_ERROR and isinstance(entry, tuple):
            return entry[0]
        return entry

    rows = [
        [MODEL_NAMES.get(model, model)] + [cell(key, model) for key in dataset_keys]
        for model in MODELS_TO_INCLUDE
    ]

    frame = pd.DataFrame(rows, columns=['Model'] + dataset_keys)
    return frame.set_index('Model')

def _format_metric_value(value, metric):
    """Format a number for display: percent (2 decimals) for "Top1 agreement", else 4 decimals."""
    if metric in ("Top1 agreement",):
        return f"{value*100:.2f}"
    return f"{value:.4f}"


def _second_largest(column):
    """Return the second entry of the NaN-free column sorted descending, or None if it has fewer than two."""
    present = column.dropna()
    if len(present) < 2:
        return None
    return present.sort_values(ascending=False).iloc[1]


def create_latex_table(table, metric, results):
    """Render the results table as a LaTeX ``table`` environment.

    The best value of every column is set in bold and the second best is
    underlined (larger is treated as better). Missing cells are printed as
    ``--``. The output uses ``\\toprule``/``\\midrule``/``\\bottomrule``,
    which are booktabs commands.

    Args:
        table: DataFrame indexed by display model name with one numeric
            column per dataset key.
        metric: Metric name; used for the caption, the label and to choose
            between percentage and 4-decimal formatting.
        results: The ``{dataset_key: {model: value}}`` mapping the table was
            built from. Only consulted when ``INCLUDE_STD_ERROR`` is true, to
            fetch the spread from ``(mean, std)`` tuples.

    Returns:
        The LaTeX source as a single newline-joined string.
    """
    # Column headers: configured display name, else a title-cased key.
    clean_datasets = [
        DATASET_NAMES.get(dataset, dataset.replace('_', ' ').title())
        for dataset in table.columns
    ]

    latex_lines = [
        "\\begin{table}[htbp]",
        "\\centering",
        "\\tiny",
        f"\\caption{{{metric} across different models and datasets}}",
        f"\\label{{tab:{metric.lower().replace(' ', '_')}}}",
        "\\begin{tabular}{l" + "c" * len(table.columns) + "}",
        "\\toprule",
        "Model & " + " & ".join(clean_datasets) + " \\\\",
        "\\midrule",
    ]

    # Per-column reference values for the bold / underline highlighting.
    max_values = table.max()
    second_max_values = {
        col_name: _second_largest(table[col_name]) for col_name in table.columns
    }

    # Display name -> internal model name, built once. The first internal
    # name wins if two models share a display name.
    display_to_model = {}
    for internal_name, display_name in MODEL_NAMES.items():
        display_to_model.setdefault(display_name, internal_name)

    for idx, row in table.iterrows():
        row_data = [idx]
        for col_name, val in row.items():
            if pd.isna(val):
                row_data.append("--")
                continue

            formatted_val = _format_metric_value(val, metric)

            if INCLUDE_STD_ERROR:
                # Models without a MODEL_NAMES entry are labelled with their
                # internal name, so fall back to the row label itself;
                # otherwise their spread would be dropped silently.
                model_name = display_to_model.get(idx, idx)
                original_value = results.get(col_name, {}).get(model_name)
                if isinstance(original_value, tuple):
                    _, std_err = original_value
                    std_err_formatted = _format_metric_value(std_err, metric)
                    formatted_val = f"{formatted_val} $\\pm$ {std_err_formatted}"

            # Highlight by rank: bold for the best, underline for second best.
            if val == max_values[col_name]:
                formatted_val = f"\\textbf{{{formatted_val}}}"
            elif second_max_values[col_name] is not None and val == second_max_values[col_name]:
                formatted_val = f"\\underline{{{formatted_val}}}"

            # Pad with invisible digits so narrower numbers line up.
            # NOTE(review): the thresholds use the raw value, which matches
            # the percentage display (x100, aligned to "100.00"); for the
            # 4-decimal metrics and for negative values the padding looks
            # off - confirm the intended alignment before changing it.
            if val < 0.1:
                formatted_val = f"\\phantom{{00}}{formatted_val}"
            elif val < 1:
                formatted_val = f"\\phantom{{0}}{formatted_val}"

            row_data.append(formatted_val)

        latex_lines.append(" & ".join(row_data) + " \\\\")

    latex_lines.append("\\bottomrule")
    latex_lines.append("\\end{tabular}")
    latex_lines.append("\\end{table}")

    return "\n".join(latex_lines)

def main():
    """Command-line entry point: load results, build the table, write LaTeX.

    The positional ``results_dir`` argument is required. ``--metric``
    defaults to ``SELECTED_METRIC`` and ``--output`` to ``output.txt``, so
    passing only the results directory uses the module-level configuration.
    Prints "No results found" and writes nothing when no result file could
    be loaded.
    """
    parser = argparse.ArgumentParser(description='Create results table from bootstrap results')
    parser.add_argument('results_dir', help='Results directory (e.g., results/)')
    parser.add_argument(
        '-m', '--metric',
        choices=AVAILABLE_METRICS,
        default=SELECTED_METRIC,
        help='Metric column to report (default: %(default)s)',
    )
    parser.add_argument(
        '-o', '--output',
        default='output.txt',
        help='File the LaTeX table is written to (default: %(default)s)',
    )

    args = parser.parse_args()

    # Load results
    results = load_results(args.results_dir, args.metric)

    if not results:
        print("No results found")
        return

    # Create table
    table = create_table(results, args.metric)

    # Save table as LaTeX
    latex_content = create_latex_table(table, args.metric, results)
    with open(args.output, 'w', encoding='utf-8') as f:
        f.write(latex_content)

# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()